import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

"""Augmenting on the GPU
One of the key concerns when it comes to training a deep learning model is
avoiding bottlenecks in your training pipeline. Well, that's not quite true
— there will always be a bottleneck. The trick is to make sure the bottleneck
is at the resource that's the most expensive or difficult to upgrade, and that
your usage of that resource isn't wasteful.

Some common places to see bottlenecks are as follows:

  - In the data-loading pipeline, either in raw I/O or in decompressing data
    once it's in RAM. We addressed this with our `diskcache` library usage.
  - In CPU preprocessing of the loaded data. This is often data normalization
    or augmentation.
  - In the training loop on the GPU. This is typically where we want our
    bottleneck to be, since total deep learning system costs for GPUs are 
    usually higher than for storage or CPU.
  - Less commonly, the bottleneck can sometimes be the memory bandwidth between
    CPU and GPU. This implies that the GPU isn’t doing much work compared to
    the data size that's being sent in.

Since GPUs can be 50 times faster than CPUs when working on tasks that fit GPUs
well, it often makes sense to move those tasks to the GPU from the CPU in cases
where CPU usage is becoming high. This is especially true if the data gets
expanded during this processing; by moving the smaller input to the GPU first,
the expanded data is kept local to the GPU, and less memory bandwidth is used.

In our case, we're going to move data augmentation to the GPU. This will keep
our CPU usage light, and the GPU will easily be able to accommodate the 
additional workload. Far better to have the GPU busy with a small bit of extra
work than idle waiting for the CPU to struggle through the augmentation process.
"""


class SegmentationAugmentation(nn.Module):
    """Random 2D affine augmentation (flip/offset/scale/rotate) plus additive
    noise, applied to a batch of images and segmentation labels on whatever
    device the input tensor lives on (typically the GPU).

    Each constructor argument enables one augmentation when truthy:
      flip:   randomly mirror each of the two spatial axes with probability 0.5
      offset: maximum translation, as a fraction of the image in the
              normalized [-1, 1] coordinates used by ``affine_grid``
      scale:  maximum relative scale change, e.g. 0.2 -> factor in [0.8, 1.2]
      rotate: enable a uniformly random rotation in [0, 2*pi)
      noise:  standard deviation of additive Gaussian noise on the image
    """

    def __init__(self, flip=None, offset=None, scale=None, rotate=None, noise=None):
        super().__init__()

        self.flip = flip
        self.offset = offset
        self.scale = scale
        self.rotate = rotate
        self.noise = noise

    def forward(self, input_g, label_g):
        """Returns (augmented_input, augmented_label_mask).

        input_g and label_g are (N, C, H, W) tensors; the label comes back as
        a boolean mask after resampling and re-thresholding at 0.5.
        """
        # One random transform is drawn per call and shared by every sample
        # in the batch (expand broadcasts the single 3x3 matrix across dim 0).
        transform_t = self._build2dTransformMatrix()
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
        transform_t = transform_t.to(input_g.device, torch.float32)
        # affine_grid wants (N, 2, 3) matrices: slice off the homogeneous row.
        affine_t = F.affine_grid(
            transform_t[:, :2], input_g.size(), align_corners=False
        )

        augmented_input_g = F.grid_sample(
            input_g, affine_t, padding_mode="border", align_corners=False
        )
        # Resample the label with the same grid. grid_sample only accepts
        # floating-point input, so cast first; the bilinear interpolation is
        # undone by the > 0.5 threshold below.
        augmented_label_g = F.grid_sample(
            label_g.to(torch.float32),
            affine_t,
            padding_mode="border",
            align_corners=False,
        )

        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t

        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        """Builds a random 3x3 homogeneous transform matrix.

        forward() keeps only the top two rows (``[:, :2]``), so the
        translation must live in the last *column*, never in the bottom row.
        """
        # Create a 3x3 matrix, but we will drop the last row later
        transform_t = torch.eye(3)

        # Again, we're augmenting 2D data here
        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i, i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = random.random() * 2 - 1
                # BUGFIX: this was transform_t[2, i], which wrote the
                # translation into the homogeneous row that forward() slices
                # away, so the offset augmentation silently did nothing.
                # F.affine_grid expects the translation in the last column of
                # the (2, 3) theta matrix.
                transform_t[i, 2] = offset_float * random_float

            if self.scale:
                scale_float = self.scale
                random_float = random.random() * 2 - 1
                transform_t[i, i] *= 1.0 + scale_float * random_float

        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            # Rotation matrix for the 2D rotation by the random angle in the
            # first two dimensions
            rotation_t = torch.tensor([[c, -s, 0], [s, c, 0], [0, 0, 1]])

            # Right-multiplication composes the rotation with the flip/scale
            # terms while leaving the translation column untouched (the
            # rotation matrix's last column is [0, 0, 1]^T).
            transform_t @= rotation_t

        return transform_t
